{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1BtkMGSYQOTQ"
      },
      "source": [
        "# Train a gesture recognition model for microcontroller use"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BaFfr7DHRmGF"
      },
      "source": [
        "This notebook demonstrates how to train a 20 KB gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It produces the same model used in the [magic_wand](https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/magic_wand) example application.\n",
        "\n",
        "This notebook is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xXgS6rxyT7Qk"
      },
      "source": [
        "Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -\u003e Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LG6ErX5FRIaV"
      },
      "source": [
        "## Configure dependencies\n",
        "\n",
        "Run the following cell to set up the files that training depends on."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "STNft9TrfoVh"
      },
      "source": [
        "We'll clone the [tflite-micro](https://github.com/tensorflow/tflite-micro) repository, which contains the training scripts, and copy them into our workspace."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ygkWw73dRNda"
      },
      "outputs": [],
      "source": [
        "# Clone the repository from GitHub\n",
        "!git clone --depth 1 -q https://github.com/tensorflow/tflite-micro\n",
        "# Copy the training scripts into our workspace\n",
        "!cp -r tflite-micro/tensorflow/lite/micro/examples/magic_wand/train train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pXI7R4RehFdU"
      },
      "source": [
        "## Prepare the data\n",
        "\n",
        "Next, we'll download the data and extract it into the expected location within the training scripts' directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "W2Sg2AKzVr2L"
      },
      "outputs": [],
      "source": [
        "# Download the data we will use to train the model\n",
        "!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n",
        "# Extract the data into the train directory\n",
        "!tar xzf data.tar.gz -C train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DNjukI1Sgl2C"
      },
      "source": [
        "We'll then run the scripts that prepare the data and split it by person into training, validation, and test sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "XBqSVpi6Vxss"
      },
      "outputs": [],
      "source": [
        "# The scripts must be run from within the train directory\n",
        "%cd train\n",
        "# Prepare the data\n",
        "!python data_prepare.py\n",
        "# Split the data by person\n",
        "!python data_split_person.py"
      ]
    },
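    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A per-person split keeps all of one person's recordings in a single set, so validation and test accuracy are measured on people the model has not seen during training. The snippet below is a minimal sketch of that idea, not the code in `data_split_person.py`: the `name` field, the record layout, and the example names and gestures are all assumptions made for illustration.\n",
        "\n",
        "```python\n",
        "def split_by_person(records, val_people, test_people):\n",
        "    # Assign each record to exactly one split based on who recorded it.\n",
        "    train, val, test = [], [], []\n",
        "    for record in records:\n",
        "        person = record['name']  # hypothetical field identifying the person\n",
        "        if person in test_people:\n",
        "            test.append(record)\n",
        "        elif person in val_people:\n",
        "            val.append(record)\n",
        "        else:\n",
        "            train.append(record)\n",
        "    return train, val, test\n",
        "\n",
        "\n",
        "# Hypothetical records; the real dataset layout may differ.\n",
        "records = [\n",
        "    {'name': 'alice', 'gesture': 'wing'},\n",
        "    {'name': 'bob', 'gesture': 'ring'},\n",
        "    {'name': 'carol', 'gesture': 'slope'},\n",
        "    {'name': 'alice', 'gesture': 'ring'},\n",
        "]\n",
        "train, val, test = split_by_person(records, {'bob'}, {'carol'})\n",
        "print(len(train), len(val), len(test))\n",
        "```\n",
        "\n",
        "Anyone not listed in `val_people` or `test_people` falls into the training set, which is one simple way to guarantee that no person appears in more than one split."
      ]
    },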
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5-cmVbFvhTvy"
      },
      "source": [
        "## Load TensorBoard\n",
        "\n",
        "Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CCx6SN9NWRPw"
      },
      "outputs": [],
      "source": [
        "# Load TensorBoard\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir logs/scalars"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ERC2Cr4PhaOl"
      },
      "source": [
        "## Begin training\n",
        "\n",
        "The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DXmQZgbuWQFO"
      },
      "outputs": [],
      "source": [
        "!python train.py --model CNN --person true"
      ]
    },
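    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Once training finishes, it can be useful to confirm that the exported model is roughly the size expected for a microcontroller target. The following is a minimal sketch, assuming the training script has written `model.tflite` to the current directory; `size_in_kib` is a hypothetical helper and is not part of the training scripts.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def size_in_kib(path):\n",
        "    # Return the size of the file in KiB (1 KiB = 1024 bytes).\n",
        "    return os.path.getsize(path) / 1024\n",
        "\n",
        "\n",
        "model_path = 'model.tflite'  # assumed output location\n",
        "# Only report when the file exists, so this is safe to run before training.\n",
        "if os.path.exists(model_path):\n",
        "    print(f'{model_path} is {size_in_kib(model_path):.1f} KiB')\n",
        "```"
      ]
    },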
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4gXbVzcXhvGD"
      },
      "source": [
        "## Create a C source file\n",
        "\n",
        "The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n",
        "\n",
        "In the code cell below, we convert this model into a C source file, `model.cc`, that we can use with TensorFlow Lite for Microcontrollers."
      ]
    },
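    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`xxd -i` writes the bytes of a file as a C array together with a length variable. To make that conversion less opaque, here is a rough Python sketch of the same idea. `to_c_array` is a hypothetical helper, and its output only approximates the layout that `xxd -i` produces; the notebook itself relies on `xxd`.\n",
        "\n",
        "```python\n",
        "from pathlib import Path\n",
        "\n",
        "\n",
        "def to_c_array(data, name='model_tflite'):\n",
        "    # Render bytes as a C array definition plus a length variable.\n",
        "    rows = []\n",
        "    for start in range(0, len(data), 12):  # 12 bytes per row\n",
        "        chunk = data[start:start + 12]\n",
        "        rows.append('  ' + ', '.join(f'0x{b:02x}' for b in chunk))\n",
        "    body = ',\\n'.join(rows)\n",
        "    return (f'unsigned char {name}[] = {{\\n{body}\\n}};\\n'\n",
        "            f'unsigned int {name}_len = {len(data)};\\n')\n",
        "\n",
        "\n",
        "model_path = Path('model.tflite')  # assumed output of train.py\n",
        "if model_path.exists():\n",
        "    # Show only the start of the generated source.\n",
        "    print(to_c_array(model_path.read_bytes())[:200])\n",
        "```\n",
        "\n",
        "The default array name mirrors how `xxd -i` derives a variable name from the file name `model.tflite`."
      ]
    },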
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8wgei4OGe3Nz"
      },
      "outputs": [],
      "source": [
        "# Install xxd if it is not available\n",
        "!apt-get -qq install xxd\n",
        "# Save the file as a C source file\n",
        "!xxd -i model.tflite \u003e /content/model.cc\n",
        "# Print the source file\n",
        "!cat /content/model.cc"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Train a gesture recognition model for microcontroller use",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
